{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RZFyZYqx74G5"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gLeUAXP5j8RJ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "import importlib\n",
        "from simulation_research.diffusion import ode_datasets\n",
        "from simulation_research.diffusion import diffusion_unet\n",
        "from simulation_research.diffusion import samplers\n",
        "from simulation_research.diffusion import train\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(diffusion_unet)\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RQwCeAGHkB3i"
      },
      "outputs": [],
      "source": [
        "# Build an NPendulum dataset (n=2) and split it into a held-out batch and training data.\n",
        "dt = .5\n",
        "# bs is both the mini-batch size and the number of held-out entries.\n",
        "bs = 400\n",
        "ds = ode_datasets.NPendulum(N=4000+bs,n=2,dt=dt)\n",
        "\n",
        "\n",
        "# Entries of ds.Zs from index bs onward are used for training; the first bs are held out as test_x.\n",
        "thetas,vs = ode_datasets.unpack(ds.Zs[bs:])\n",
        "test_x = ode_datasets.unpack(ds.Zs[:bs])[0]\n",
        "#thetas /=thetas.std()\n",
        "#thetas = jax.random.normal(jax.random.PRNGKey(38),thetas.shape)\n",
        "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
        "\n",
        "# The method is intentionally not called here: dataiter is a factory, and each dataiter()\n",
        "# call builds a new iterator over the shuffled, batched dataset.\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FYsRAhjYp3Y1"
      },
      "outputs": [],
      "source": [
        "# Call the dataset's animate() helper; being the last expression, its result is the cell output.\n",
        "ds.animate()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JbFiGz6W8Qgc"
      },
      "outputs": [],
      "source": [
        "# Use the held-out batch as the reference data for model construction and evaluation.\n",
        "x = test_x#next(dataiter())\n",
        "# NOTE(review): no other cell visible in this notebook reads the global `t` - confirm it is still needed.\n",
        "t = np.random.rand(x.shape[0])\n",
        "# The UNet output dimension is taken from the last axis of x.\n",
        "model = diffusion_unet.UNet(diffusion_unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q1KyXGGrQ28E"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap\n",
        "@jit\n",
        "def rel_err(x,y):\n",
        "  \"\"\"Symmetric relative L1 error between x and y, reduced over the last axis.\"\"\"\n",
        "  abs_difference = jnp.abs(x-y).sum(-1)\n",
        "  total_magnitude = jnp.abs(x).sum(-1)+jnp.abs(y).sum(-1)\n",
        "  return abs_difference/total_magnitude\n",
        "\n",
        "\n",
        "# Number of leading time steps dropped from a sample before it is scored.\n",
        "kstart=10\n",
        "@jit\n",
        "def log_prediction_metric(qs):\n",
        "  \"\"\"Mean log relative error between one sampled trajectory and an ODE rollout from its own start.\n",
        "\n",
        "  The first `kstart` steps of `qs` are dropped, velocities are estimated by central\n",
        "  differences, and the packed state is compared with `ds.integrate` started from\n",
        "  its first point.\n",
        "  \"\"\"\n",
        "  k=kstart\n",
        "  q = qs[k:]\n",
        "  # Central finite difference (q[t+1]-q[t-1])/(2*dt), defined on the interior points q[1:-1].\n",
        "  v = -(q[:-2]-q[2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "  #print(vmap(ds.M)(q[1:-1]).shape,v.shape)\n",
        "  # The second argument to pack is M(q) @ v (presumably the momentum - TODO confirm).\n",
        "  z = ode_datasets.pack(q[1:-1],(vmap(ds.M)(q[1:-1])@v[...,None]).squeeze(-1))\n",
        "  # Times matching the interior points of the truncated trajectory.\n",
        "  T = ds.T_long[k+1:-1]\n",
        "  z_gt = ds.integrate(z[0],T)\n",
        "  # Average the log error over indices 1 .. len(T)//3 - 1 only.\n",
        "  return jnp.log(rel_err(z,z_gt)[1:len(T)//3]).mean()\n",
        "\n",
        "@jit\n",
        "def pmetric(qs):\n",
        "  \"\"\"Return (geometric mean, multiplicative standard-error factor) of the per-trajectory error.\"\"\"\n",
        "  per_traj_log_err = vmap(log_prediction_metric)(qs)\n",
        "  num_traj = per_traj_log_err.shape[0]\n",
        "  geo_mean = jnp.exp(per_traj_log_err.mean())\n",
        "  geo_stderr = jnp.exp(per_traj_log_err.std()/jnp.sqrt(num_traj))\n",
        "  return geo_mean,geo_stderr"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PUta3mZuTkaQ"
      },
      "outputs": [],
      "source": [
        "# Train the score model for the selected noise/diffusion pair, then report its NLL and rollout error.\n",
        "noisetype='White'#@param ['White','Pink','Brown']\n",
        "noise = {'White':train.Identity,'Pink':train.PinkCovariance,'Brown':train.BrownianCovariance}[noisetype]\n",
        "difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
        "diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,\n",
        "        'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype](noise)\n",
        "epochs = 2000#@param {'type':'integer'}\n",
        "score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=3e-4)\n",
        "key= jax.random.PRNGKey(38)\n",
        "nll = samplers.compute_nll(diff,score_fn,key,x).mean()\n",
        "stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "err = pmetric(stoch_samples)[0]\n",
        "# Store the summary in `outputs`: the cells below print that list, and no other live code\n",
        "# assigns it, so they would otherwise raise NameError on a fresh kernel.\n",
        "outputs = [f\"{noise.__name__} gets NLL {nll:.3f} and err {err:.3f}\"]\n",
        "print(outputs[-1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nUUCduoEkaYh"
      },
      "outputs": [],
      "source": [
        "# noisetype='White'#@param ['White','Pink','Brown']\n",
        "# noise = {'White':train.Identity,'Pink':train.PinkCovariance,'Brown':train.BrownianCovariance}[noisetype]\n",
        "\n",
        "# difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
        "# diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,\n",
        "#         'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype](noise)\n",
        "# epochs = 2000#@param {'type':'integer'}\n",
        "\n",
        "#outputs= []\n",
        "# for noise in [train.WhiteCovariance]:#train.BrownianCovariance,train.PinkCovariance,train.WhiteCovariance]:\n",
        "#   for multiplier in [.1,.3,1,3,10,30]:\n",
        "#     noise.multiplier=multiplier\n",
        "#     diff = train.VarianceExploding(noise)\n",
        "#     score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=3e-4)\n",
        "#     key= jax.random.PRNGKey(38)\n",
        "#     nll = samplers.compute_nll(diff,score_fn,key,x).mean()\n",
        "#     stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "#     err = pmetric(stoch_samples)[0]\n",
        "#     outputs.append(f\"{noise.__name__} with scale {multiplier} gets NLL {nll:.3f} and err {err:.3f}\")\n",
        "#     print(outputs[-1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y-P8uVpcWP-x"
      },
      "outputs": [],
      "source": [
        "# Show the collected result strings as a raw Python list.\n",
        "print(outputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NWv-JuOihAkd"
      },
      "outputs": [],
      "source": [
        "# Keep a second name for the current results list (an alias of the same object, not a copy).\n",
        "outputs2=outputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1wondoGVgzg9"
      },
      "outputs": [],
      "source": [
        "# Print each collected result string on its own line.\n",
        "for output in outputs:\n",
        "  print(output)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fRn9pR8Ewzey"
      },
      "outputs": [],
      "source": [
        "# Re-import the sampler and training modules so that edits to their source are picked up.\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P1VdapJf4Omo"
      },
      "outputs": [],
      "source": [
        "# Draw noise with the same shape as x from the selected noise class.\n",
        "samples = noise.sample(jax.random.PRNGKey(39),x.shape)\n",
        "# NOTE(review): the next two assignments (cumulative sum of white noise along axis 1) are\n",
        "# overwritten by the last line of this cell, so they do not affect the plots below.\n",
        "samples2 = jax.random.normal(jax.random.PRNGKey(39),x.shape)\n",
        "samples2 = jnp.cumsum(samples2,axis=1)\n",
        "# First half of each data trajectory followed by its mirror image along axis 1.\n",
        "half_x = x[:,:x.shape[1]//2]\n",
        "samples2 = jnp.concatenate([half_x,half_x[:,::-1]],axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TY5VdTAx4Xlz"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# Compare entry i of the noise draw with entry i of the mirrored data signal (last channel).\n",
        "i=1 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# NOTE(review): the 'brown1' label is fixed even though `noise` depends on the noisetype form.\n",
        "plt.plot(ds.T_long,samples[i,:,-1].T,alpha=1/2,label='brown1')\n",
        "plt.plot(ds.T_long,samples2[i,:,-1].T,alpha=1/2,label='data')\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zJk34SSU5r64"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# NOTE(review): the slider value i is not used in this cell; every 25th entry is plotted instead.\n",
        "i=5 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Magnitude of the real FFT along axis 1 for the last channel of every 25th entry.\n",
        "fourier_mag1 = jnp.abs(jnp.fft.rfft(samples[::25,:,-1],axis=1))\n",
        "fourier_mag2 = jnp.abs(jnp.fft.rfft(samples2[::25,:,-1],axis=1))\n",
        "plt.plot(fourier_mag1.T,alpha=1/5,label='brown1',color='y')\n",
        "plt.plot(fourier_mag2.T,alpha=1/5,label='data',color='brown')\n",
        "# Log-log axes for both spectra.\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "# plt.xlabel('Time t')\n",
        "# plt.ylabel(r'State')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8nc8LmPLL_z4"
      },
      "source": [
        "Sample generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kboc2IIeteh7"
      },
      "outputs": [],
      "source": [
        "# Draw batches shaped like x[:30], reusing the same PRNG key for all three calls.\n",
        "stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "# traj=True: the result is indexed with a leading sampler-step axis in the cells below.\n",
        "sample_traj = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=True)\n",
        "# Samples from samplers.sample; a later plot labels these 'Model (ODE)'.\n",
        "det_samples = samplers.sample(diff,score_fn,key,x[:30].shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5rrpbT2MAWRn"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "i=6 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Every 100th stored sampler step of entry i (last channel), overlaid on one axis.\n",
        "plt.plot(ds.T_long,sample_traj[0::100,i,:,-1].T,alpha=1/2)\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "#plt.ylim(-5,5)\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LerwGMP9Bla"
      },
      "outputs": [],
      "source": [
        "from jax import vmap\n",
        "# One diffusion time per stored sampler step, in decreasing order: (n-0.5)/n, ..., 1.5/n.\n",
        "n=sample_traj.shape[0]+1\n",
        "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
        "# Evaluate the learned score at every stored step (vmapped over the step axis).\n",
        "scores = vmap(score_fn)(sample_traj,ts).reshape(sample_traj.shape)\n",
        "# One-step estimate of the clean signal: (x_t + sigma(t)^2 * score) / scale(t).\n",
        "best_reconstructions = (sample_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xr1eZyO-1IWB"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "\n",
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=4 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "\n",
        "cmap='inferno'\n",
        "\n",
        "# One-step reconstructions of sample i (last channel), every 25th sampler step from step 100 on.\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = best_reconstructions[100::25,i,:,-1].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# A separate loop index keeps the sample index i chosen above intact.\n",
        "for k,line in enumerate(ax1.lines):\n",
        "    line.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[100] the upper.\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ohSLTVT9f2M"
      },
      "outputs": [],
      "source": [
        "from scipy.ndimage import correlate1d\n",
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=22 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "# correlate1d with weights [-1,0,1] already yields x[t+1]-x[t-1], so the central-difference\n",
        "# velocity needs no extra sign flip (this matches the slicing form used by the metric functions).\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "print(vs.shape)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = vs[100::25,i,:,-1].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# A separate loop index keeps the sample index i chosen above intact.\n",
        "for k,line in enumerate(ax1.lines):\n",
        "    line.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'$\\dot \\theta$')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[100] the upper.\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5UZzL2APAxpy"
      },
      "outputs": [],
      "source": [
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=12 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "# correlate1d with weights [-1,0,1] already yields x[t+1]-x[t-1]; negating it would flip the\n",
        "# sign of the velocity and hence of the second component packed into z and z0.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "z = ode_datasets.pack(best_reconstructions,(vmap(vmap(vmap(ds.M)))(best_reconstructions)@vs[...,None]).squeeze(-1))\n",
        "# ds.H evaluated at every (sampler step, sample, time) state.\n",
        "Hs = vmap(vmap(vmap(ds.H)))(z)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = Hs[100::25,i,:].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'Energy')\n",
        "# Centre the y-range on the last plotted curve, padded by twice its spread on each side.\n",
        "minn,maxx = data[:,-1].min(),data[:,-1].max()\n",
        "plt.ylim(minn-2*(maxx-minn),maxx+2*(maxx-minn))\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[100] the upper.\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Green curve: ds.H along a ds.integrate rollout started from the final step's first state.\n",
        "z0 = z[-1,:,0]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,ds.T_long)\n",
        "ax1.plot(ds.T_long,vmap(vmap(ds.H))(z_gts)[i],color='g',lw=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fiq7MGQm-1x-"
      },
      "outputs": [],
      "source": [
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=15 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "nn = sample_traj.shape[2]\n",
        "# Magnitude spectrum along the time axis (axis=2) of every stored sampler step.\n",
        "fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))#[:,:,:nn//2]\n",
        "freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[0::25,i,:,-1].T\n",
        "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# A separate loop index keeps the sample index i chosen above intact.\n",
        "for k,line in enumerate(ax1.lines):\n",
        "    line.set_color(colors[k])\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[0] the upper.\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[0])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Faint blue curves: spectra of every 10th held-out data entry, for reference.\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSCNVbsV-zlI"
      },
      "outputs": [],
      "source": [
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=8 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "nn = best_reconstructions.shape[2]\n",
        "# Magnitude spectrum along the time axis (axis=2) of the one-step reconstructions.\n",
        "fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))#[:,:,:nn//2]\n",
        "freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[100::25,i,:,-1].T\n",
        "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# A separate loop index keeps the sample index i chosen above intact.\n",
        "for k,line in enumerate(ax1.lines):\n",
        "    line.set_color(colors[k])\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[100] the upper.\n",
        "norm = mpl.colors.Normalize(vmax=ts[100], vmin=ts[-25])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Faint blue curves: spectra of every 10th held-out data entry, for reference.\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fa5_fj-mU_zw"
      },
      "outputs": [],
      "source": [
        "\n",
        "import matplotlib.pyplot as plt\n",
        "i=20 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Overlay entry i of the held-out data with entry i of the ODE and SDE sample batches (last channel).\n",
        "# NOTE(review): the samples were drawn without conditioning on x, so curve i is presumably\n",
        "# not meant to reproduce x[i] - confirm the intended comparison.\n",
        "plt.plot(ds.T_long,x[i,:,-1])\n",
        "plt.plot(ds.T_long,det_samples[i,:,-1])\n",
        "plt.plot(ds.T_long,stoch_samples[i,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT',r'Model (ODE)', r'Model (SDE)'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QFHO0tDBL42D"
      },
      "source": [
        "Test ability to condition model on previous timesteps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cH8P02jcfKRH"
      },
      "outputs": [],
      "source": [
        "from jax import grad,jit\n",
        "# Number of leading time steps that are treated as observed.\n",
        "condition_amount = 13# @param {type:\"slider\", min:0, max:50, step:1}\n",
        "# Mini-batch of the first 30 held-out entries used for conditioning.\n",
        "mb = x[:30,:]\n",
        "\n",
        "\n",
        "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Wrap `scorefn` so the time steps selected by `slc` use `diffusion.noise_score` of the observed values.\n",
        "\n",
        "  Returns a function (xt, t) -> score array of shape (batch, -1, channels), where batch\n",
        "  and channels are read from `observed_values.shape`.\n",
        "  \"\"\"\n",
        "  batch,_,channels = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    xt_observed = xt.reshape(batch,-1,channels)[:,slc]\n",
        "    # Score of the observed entries, computed directly from their known values.\n",
        "    known_score = diffusion.noise_score(xt_observed,observed_values,t)\n",
        "    model_score = scorefn(xt,t).reshape(batch,-1,channels)\n",
        "    return model_score.at[:,slc].set(known_score)\n",
        "  return conditioned_scores\n",
        "\n",
        "def inpainting_scores2(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Inpainting score with an extra guidance term on the unobserved entries.\n",
        "\n",
        "  Entries selected by `slc` use `diffusion.noise_score` of the observed values. The model\n",
        "  score is additionally shifted by the gradient of the squared mismatch between the\n",
        "  one-step reconstruction and the observations on `slc`.\n",
        "  Returns a jitted function (xt, t) -> score array of shape (batch, -1, channels).\n",
        "  \"\"\"\n",
        "  b,n,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def constraint(xt):\n",
        "      # One-step estimate of the clean signal, compared with the observations on slc.\n",
        "      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
        "    #unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*10/(diff.g2(t)/2)\n",
        "    # Use the `diffusion` argument (not the notebook-level `diff`) so the wrapper is self-contained.\n",
        "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*30*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
        "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
        "    return combined_score#.reshape(-1)\n",
        "  return jit(conditioned_scores)\n",
        "\n",
        "# Condition on the first condition_amount time steps of each entry in mb.\n",
        "slc = slice(condition_amount)\n",
        "# traj=True: the result is indexed with a leading sampler-step axis in the cells below.\n",
        "conditioned_samples = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,N=1000,traj=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUMkKWJuAy2y"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "i=1 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Faint curves: every 100th of the last 600 stored sampler steps for entry i (last channel).\n",
        "plt.plot(ds.T_long,conditioned_samples[-600::100,i,:,-1].T,zorder=0,alpha=.2)\n",
        "# Final sampler step, i.e. the completed trajectory.\n",
        "plt.plot(ds.T_long,conditioned_samples[-1,i,  :,-1].T,zorder=2,label='prediction')\n",
        "plt.plot(ds.T_long,x[i,:,-1],label='ground truth',alpha=1,zorder=99)\n",
        "# Thick segment: the observed prefix that the sampler was conditioned on.\n",
        "plt.plot(ds.T_long[slc],x[i,slc,-1],label='conditioning',alpha=1,zorder=100,lw=3)\n",
        "\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.ylim(-3,3)\n",
        "plt.legend()\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "21fgx9UbckBK"
      },
      "outputs": [],
      "source": [
        "# Same conditioned score function, drawn with samplers.sample instead of stochastic_sample.\n",
        "conditioned_sample = samplers.sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xKR36dO6FgwA"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap,random\n",
        "\n",
        "@jit\n",
        "def rel_err(z1,z2):\n",
        "  \"\"\"Symmetric relative L1 error between z1 and z2, reduced over the last axis.\n",
        "\n",
        "  The difference is normalized by the sum of the two L1 norms, as in the other rel_err\n",
        "  definitions of this notebook. Dividing by their product would make the value depend\n",
        "  on the overall scale of the inputs.\n",
        "  \"\"\"\n",
        "  return jnp.abs(z1-z2).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1))\n",
        "\n",
        "# Compare the SDE and ODE completions with the held-out data they were conditioned on.\n",
        "gt = x[:30]\n",
        "for pred in [conditioned_samples[-1],conditioned_sample]:\n",
        "  # Clamp from below so the logarithm stays finite where the error is close to zero.\n",
        "  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)\n",
        "  # Geometric mean and geometric standard deviation over the batch axis, per time step.\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(ds.T_long,rel_errs)\n",
        "  # Multiplicative band: mean divided and multiplied by the geometric standard deviation.\n",
        "  plt.fill_between(ds.T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['SDE completion','ODE completion'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Xfy7NxAfQ8S"
      },
      "outputs": [],
      "source": [
        "i=23 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "# Held-out trajectory, its observed prefix (thick), and the samplers.sample completion for entry i.\n",
        "plt.plot(ds.T_long,x[i,:,-1])\n",
        "plt.plot(ds.T_long[slc],x[i,slc,-1],lw=3)\n",
        "plt.plot(ds.T_long,conditioned_sample[i,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT','Conditioning',r'Model'])\n",
        "plt.ylim(-3,3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMvSdxUUG9Op"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovmzQCeoG_go"
      },
      "source": [
        "Energy Conditioning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zYF_N85SHDZU"
      },
      "outputs": [],
      "source": [
        "from jax import grad,jit\n",
        "# Batch and channel sizes used by the score wrapper below are taken from this mini-batch.\n",
        "mb = x[:30,:]\n",
        "b,_,c = mb.shape\n",
        "data_std = thetas.std()\n",
        "def energy_conditioned_scores(diffusion,scorefn):\n",
        "  \"\"\"Wrap `scorefn` with a guidance term that penalizes energy variation along each trajectory.\n",
        "\n",
        "  The penalty is the summed squared deviation of ds.H from its per-trajectory mean,\n",
        "  evaluated on the one-step reconstruction of the clean signal. Reads the notebook-level\n",
        "  names b, c, data_std and ds. Returns a jitted function (xt, t) -> score array of\n",
        "  shape (b, -1, c).\n",
        "  \"\"\"\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def constraint(xt):\n",
        "      # One-step estimate of the clean signal.\n",
        "      q = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "      q = q.reshape(b,-1,c)\n",
        "      #v = -correlate1d(one_step_xhat,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=1)\n",
        "      # Central finite difference on the interior time steps.\n",
        "      v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "      z = ode_datasets.pack(q[:,1:-1],(vmap(vmap(ds.M))(q[:,1:-1])@v[...,None]).squeeze(-1))\n",
        "      Hs = vmap(vmap(ds.H))(z)\n",
        "      return jnp.sum(jnp.square(Hs-Hs.mean(axis=1,keepdims=True)))\n",
        "    #unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*10/(diff.g2(t)/2)\n",
        "    # Use the `diffusion` argument (not the notebook-level `diff`) for the guidance weight.\n",
        "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*(diffusion.scale(t)**2/diffusion.sigma(t)**2)*.5/(1+diffusion.sigma(t)**2/data_std**2)\n",
        "    return unobserved_score\n",
        "  return jit(conditioned_scores)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdqVaK4bOvfo"
      },
      "outputs": [],
      "source": [
        "from jax import vmap,grad,jit\n",
        "from scipy.ndimage import correlate1d\n",
        "\n",
        "# Sample with the energy-conditioned score and keep the stored sampler steps (traj=True).\n",
        "energy_samples_traj = samplers.stochastic_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,N=1000,traj=True)\n",
        "# One diffusion time per stored sampler step, in decreasing order: (n-0.5)/n, ..., 1.5/n.\n",
        "n=energy_samples_traj.shape[0]+1\n",
        "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
        "scores = vmap(score_fn)(energy_samples_traj,ts).reshape(energy_samples_traj.shape)\n",
        "# One-step estimate of the clean signal: (x_t + sigma(t)^2 * score) / scale(t).\n",
        "best_reconstructions = (energy_samples_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]\n",
        "# correlate1d with weights [-1,0,1] already yields x[t+1]-x[t-1]; negating it would flip the\n",
        "# sign of the velocity and hence of the second component packed into z and z0.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "z = ode_datasets.pack(best_reconstructions,(vmap(vmap(vmap(ds.M)))(best_reconstructions)@vs[...,None]).squeeze(-1))\n",
        "Hs = vmap(vmap(vmap(ds.H)))(z)\n",
        "# Start the reference rollout kstart steps in, from the final sampler step.\n",
        "kstart=10\n",
        "z0 = z[-1,:,kstart]\n",
        "#print(vmap(ds.H)(z0)[i])\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,ds.T_long[kstart:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RDhUt4f-vAfM"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9DTppttzKL7k"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "\n",
        "# Valid sample indices are 0..29 because the sampled batch has 30 entries.\n",
        "i=14 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "\n",
        "cmap='inferno'\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = Hs[100::25,i,:].T\n",
        "#print(data[:,-1])\n",
        "# Rows of data follow ds.T_long, so they are plotted in their original (not reversed) order.\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "# plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'Energy')\n",
        "# Centre the y-range on the last plotted curve, padded by twice its spread on each side.\n",
        "minn,maxx = data[:,-1].min(),data[:,-1].max()\n",
        "plt.ylim(minn-2*(maxx-minn),maxx+2*(maxx-minn))\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# ts decreases along the sampler trajectory, so ts[-25] is the lower limit and ts[100] the upper.\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "\n",
        "# Green curve: ds.H along the reference rollout, which starts kstart steps in.\n",
        "ax1.plot(ds.T_long[kstart:],vmap(vmap(ds.H))(z_gts)[i],color='g',lw=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mc0IHMpaDrMR"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap\n",
        "# These definitions repeat the metric cell near the top of the notebook; running this cell\n",
        "# rebinds the global names rel_err, kstart, log_prediction_metric and pmetric.\n",
        "@jit\n",
        "def rel_err(x,y):\n",
        "  \"\"\"Symmetric relative L1 error between x and y, reduced over the last axis.\"\"\"\n",
        "  return  jnp.abs(x-y).sum(-1)/(jnp.abs(x).sum(-1)+jnp.abs(y).sum(-1))\n",
        "\n",
        "\n",
        "# Number of leading time steps dropped from a sample before it is scored.\n",
        "kstart=10\n",
        "@jit\n",
        "def log_prediction_metric(qs):\n",
        "  \"\"\"Mean log relative error between one sampled trajectory and an ODE rollout from its own start.\"\"\"\n",
        "  k=kstart\n",
        "  q = qs[k:]\n",
        "  # Central finite difference (q[t+1]-q[t-1])/(2*dt), defined on the interior points q[1:-1].\n",
        "  v = -(q[:-2]-q[2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "  #print(vmap(ds.M)(q[1:-1]).shape,v.shape)\n",
        "  # The second argument to pack is M(q) @ v (presumably the momentum - TODO confirm).\n",
        "  z = ode_datasets.pack(q[1:-1],(vmap(ds.M)(q[1:-1])@v[...,None]).squeeze(-1))\n",
        "  T = ds.T_long[k+1:-1]\n",
        "  z_gt = ds.integrate(z[0],T)\n",
        "  # Average the log error over indices 1 .. len(T)//3 - 1 only.\n",
        "  return jnp.log(rel_err(z,z_gt)[1:len(T)//3]).mean()\n",
        "\n",
        "@jit\n",
        "def pmetric(qs):\n",
        "  \"\"\"Return (geometric mean, multiplicative standard-error factor) of the per-trajectory error.\"\"\"\n",
        "  log_metric = vmap(log_prediction_metric)(qs)\n",
        "  return jnp.exp(log_metric.mean()),jnp.exp(log_metric.std()/jnp.sqrt(log_metric.shape[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKRowUEwJ7ea"
      },
      "outputs": [],
      "source": [
        "# Energy-conditioned stochastic sampling; note N=2000 here versus N=1000 in the other calls.\n",
        "energy_samples_stoch = samplers.stochastic_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,N=2000,traj=False)\n",
        "# Only the first element of pmetric (the geometric-mean error) is reported.\n",
        "print(f'SDE performance {pmetric(energy_samples_stoch)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FM9Z3uXvtgnt"
      },
      "outputs": [],
      "source": [
        "# Energy-conditioned sampling with samplers.sample, scored with the same metric.\n",
        "energy_samples_det = samplers.sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape)\n",
        "print(f'ODE performance {pmetric(energy_samples_det)[0]}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fAyaiDRrMDd5"
      },
      "source": [
        "Unconditional Prediction quality"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ec8x-pzpwIUg"
      },
      "outputs": [],
      "source": [
        "# Baseline without any conditioning: same sampler calls with the plain score_fn.\n",
        "stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "det_samples = samplers.sample(diff,score_fn,key,x[:30].shape)\n",
        "# Only the first element of pmetric (the geometric-mean error) is reported.\n",
        "print(f'ODE performance {pmetric(det_samples)[0]}')\n",
        "print(f'SDE performance {pmetric(stoch_samples)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fg2JsQacJrBV"
      },
      "outputs": [],
      "source": [
        "from jax import random\n",
        "key = random.PRNGKey(45)\n",
        "# Samples to roll out; swap in energy_samples_det or energy_samples_stoch to evaluate those instead.\n",
        "s = stoch_samples\n",
        "\n",
        "k = 5  # leading timesteps dropped from every sample before building the rollout state\n",
        "q = s[:,k:]\n",
        "# Central finite difference of q along axis 1, using the spacing of ds.T.\n",
        "v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "# Pack the interior points of q together with M(q) @ v\n",
        "# (presumably a mass matrix applied to velocities -- TODO confirm against ode_datasets).\n",
        "z = ode_datasets.pack(q[:,1:-1],(vmap(vmap(ds.M))(q[:,1:-1])@v[...,None]).squeeze(-1))\n",
        "T = ds.T_long[k+1:-1]\n",
        "z0 = z[:,0]\n",
        "# Integrate a batch of initial conditions over the shared time grid T.\n",
        "integrate_batch = vmap(ds.integrate,(0,None),0)\n",
        "z_gts = integrate_batch(z0,T)\n",
        "# Draw the 1e-3 perturbation from a key derived from the seeded `key`, so that\n",
        "# re-running the cell reproduces z_pert; fold_in leaves `key` itself unchanged.\n",
        "pert_key = random.fold_in(key,1)\n",
        "z_pert = integrate_batch(z0+1e-3*random.normal(pert_key,z0.shape,dtype=z0.dtype),T)\n",
        "z_random = integrate_batch(ds.sample_initial_conditions(z0.shape[0]),T)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b-DH4XmZMwWX"
      },
      "outputs": [],
      "source": [
        "# For each rollout: geometric mean (line) and geometric std (band) over axis 0 of\n",
        "# its relative error against z_gts, with errors floored at 1e-3 before the log.\n",
        "for rollout in (z,z_pert,z_random):\n",
        "  log_errs = jnp.log(jax.lax.clamp(1e-3,rel_err(rollout,z_gts),np.inf))\n",
        "  center = np.exp(log_errs.mean(0))\n",
        "  spread = np.exp(log_errs.std(0))\n",
        "  plt.plot(T,center)\n",
        "  plt.fill_between(T,center/spread,center*spread,alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gxOy42H5MyDV"
      },
      "outputs": [],
      "source": [
        "# Energy error of each rollout against the energies of the ground-truth trajectories.\n",
        "batched_H = vmap(vmap(ds.H))\n",
        "H_gts = batched_H(z_gts)\n",
        "for rollout in (z,z_pert,z_random):\n",
        "  Hs = batched_H(rollout)\n",
        "  # NOTE(review): the difference is normalised by |Hs*H_gts| (a product rather than a\n",
        "  # sum or geometric mean of the two energies) -- confirm this scaling is intended.\n",
        "  energy_errs = jnp.abs(Hs-H_gts)/jnp.abs(Hs*H_gts)\n",
        "  log_errs = jnp.log(jax.lax.clamp(1e-3,energy_errs,np.inf))  # floored at 1e-3\n",
        "  center = np.exp(log_errs.mean(0))\n",
        "  spread = np.exp(log_errs.std(0))\n",
        "  plt.plot(T,center)\n",
        "  plt.fill_between(T,center/spread,center*spread,alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Energy Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tQ0FJA2M94F"
      },
      "source": [
        "Trajectory comparison: ground truth vs. model rollout"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHKOgFXNJxY6"
      },
      "outputs": [],
      "source": [
        "# One figure per sample index: component 0 of the ground-truth, model and perturbed rollouts.\n",
        "for idx in range(10):\n",
        "  fig,ax = plt.subplots()\n",
        "  for series in (z_gts,z,z_pert):\n",
        "    ax.plot(T,series[idx,:,0])\n",
        "  ax.set_xlabel('Time t')\n",
        "  ax.set_ylabel(r'State')\n",
        "  ax.legend(['gt','model','pert'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEMilLlTJyzM"
      },
      "outputs": [],
      "source": [
        "# One figure per sample index: first and last state components, ground truth vs. model.\n",
        "for idx in range(10):\n",
        "  fig,ax = plt.subplots()\n",
        "  for component in (0,-1):\n",
        "    ax.plot(T,z_gts[idx,:,component])\n",
        "    ax.plot(T,z[idx,:,component])\n",
        "  ax.set_xlabel('Time t')\n",
        "  ax.set_ylabel(r'State')\n",
        "  ax.legend([r'$\\theta_0$ gt',r'$\\theta_0$ model',r'v gt', r'v model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KYr49SCypJTT"
      },
      "outputs": [],
      "source": [
        "# Show the dataset's own animation as the cell output.\n",
        "# NOTE(review): presumably rendered inline through the 'jshtml' animation setting\n",
        "# configured in the imports cell -- confirm against ds.animate.\n",
        "ds.animate()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tEXOhfcMmho"
      },
      "outputs": [],
      "source": [
        "# Sweep the number of SDE sampler steps and record pmetric for each setting.\n",
        "Ns = [25,50,100,200,500,1000,2000]\n",
        "metric_vals = []\n",
        "metric_stds = []\n",
        "for N in Ns:\n",
        "  # traj=False is passed explicitly, as in the other stochastic_sample calls that feed\n",
        "  # pmetric, so only the samples are requested regardless of the sampler's default.\n",
        "  # A dedicated name keeps the rollout cell's `s` from being overwritten.\n",
        "  sweep_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=N,traj=False)\n",
        "  mean,std = pmetric(sweep_samples)\n",
        "  metric_vals.append(mean)\n",
        "  metric_stds.append(std)\n",
        "metric_vals = np.array(metric_vals)\n",
        "metric_stds = np.array(metric_stds)\n",
        "\n",
        "plt.plot(Ns,metric_vals)\n",
        "# pmetric's second output is a multiplicative factor, hence the divide/multiply band.\n",
        "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
        "plt.xlabel('Sampler steps')\n",
        "plt.ylabel('Pmetric value')\n",
        "plt.xscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qt2vnvKFoBuo"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D467_GYqB0Vo"
      },
      "outputs": [],
      "source": [
        "# Fourier spectrum of the angle trajectories of a freshly generated 2-pendulum dataset.\n",
        "ds2 = ode_datasets.NPendulum(N=50,n=2,dt=.002)\n",
        "thetas2,vs2 = ode_datasets.unpack(ds2.Zs)\n",
        "test_x2 = thetas2\n",
        "freq = np.fft.rfftfreq(test_x2.shape[1],d=(ds2.T[1]-ds2.T[0]))\n",
        "ffts = jnp.abs(np.fft.rfft(test_x2,axis=1))\n",
        "# Geometric mean over axis 0: average in log space first, then exponentiate.\n",
        "# (Taking the mean after the exp would cancel the log and give a plain arithmetic mean.)\n",
        "avg_freq = jnp.exp(jnp.log(ffts[:,:,-1]).mean(0))\n",
        "plt.plot(freq,avg_freq,color='blue',label=r'2 Pendulum $\\theta$');\n",
        "# Reference power laws; the zero-frequency bin is skipped to avoid dividing by zero.\n",
        "nonzero_freq = freq[1:]\n",
        "plt.plot(nonzero_freq,100/nonzero_freq,label=r'1/f')\n",
        "plt.plot(nonzero_freq,30/nonzero_freq**2,label=r'1/$f^2$')\n",
        "plt.plot(nonzero_freq,10/nonzero_freq**3,label=r'1/$f^3$')\n",
        "# Individual spectra (last component) drawn faintly behind the average.\n",
        "plt.plot(freq,ffts[:,:,-1].T,color='y',alpha=.2);\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.ylim(1e-3,1e5)\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wLgvOU60M5lt"
      },
      "outputs": [],
      "source": [
        "# Fourier spectrum of samples drawn from the noise process, for comparison with the data spectrum.\n",
        "# A dedicated name is used so the rollout state `z` built earlier is not overwritten.\n",
        "noise_samples = noise.sample(jax.random.PRNGKey(38),x.shape)\n",
        "freq = np.fft.rfftfreq(noise_samples.shape[1],d=(ds.T[1]-ds.T[0]))\n",
        "ffts = jnp.abs(np.fft.rfft(noise_samples,axis=1))\n",
        "# Geometric mean over axis 0: average in log space first, then exponentiate.\n",
        "# (Taking the mean after the exp would cancel the log and give a plain arithmetic mean.)\n",
        "avg_freq = jnp.exp(jnp.log(ffts[:,:,-1]).mean(0))\n",
        "# The curve shows the noise samples, so it is labelled as such rather than as pendulum data.\n",
        "plt.plot(freq,avg_freq,color='blue',label=r'Noise sample');\n",
        "# Reference power laws; the zero-frequency bin is skipped to avoid dividing by zero.\n",
        "nonzero_freq = freq[1:]\n",
        "plt.plot(nonzero_freq,100/nonzero_freq,label=r'1/f')\n",
        "plt.plot(nonzero_freq,30/nonzero_freq**2,label=r'1/$f^2$')\n",
        "plt.plot(nonzero_freq,10/nonzero_freq**3,label=r'1/$f^3$')\n",
        "# Individual spectra (last component) drawn faintly behind the average.\n",
        "plt.plot(freq,ffts[:,:,-1].T,color='y',alpha=.2);\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.ylim(1e-3,1e5)\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.legend()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
        "kind": "private"
      },
      "name": "train_simplified.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
